{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Uncomment and run this cell if you're on Colab or Kaggle\n",
    "# !git clone https://github.com/nlp-with-transformers/notebooks.git\n",
    "# %cd notebooks\n",
    "# from install import *\n",
    "# install_requirements()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using transformers v4.11.3\n",
      "Using datasets v1.16.1\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from utils import *\n",
    "setup_chapter()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Multilingual Named Entity Recognition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## The Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "            0      1   2  3         4          5   6       7   8           9\nTokens   Jeff   Dean  is  a  computer  scientist  at  Google  in  California\nTags    B-PER  I-PER   O  O         O          O   O   B-ORG   O       B-LOC",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>0</th>\n      <th>1</th>\n      <th>2</th>\n      <th>3</th>\n      <th>4</th>\n      <th>5</th>\n      <th>6</th>\n      <th>7</th>\n      <th>8</th>\n      <th>9</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>Tokens</th>\n      <td>Jeff</td>\n      <td>Dean</td>\n      <td>is</td>\n      <td>a</td>\n      <td>computer</td>\n      <td>scientist</td>\n      <td>at</td>\n      <td>Google</td>\n      <td>in</td>\n      <td>California</td>\n    </tr>\n    <tr>\n      <th>Tags</th>\n      <td>B-PER</td>\n      <td>I-PER</td>\n      <td>O</td>\n      <td>O</td>\n      <td>O</td>\n      <td>O</td>\n      <td>O</td>\n      <td>B-ORG</td>\n      <td>O</td>\n      <td>B-LOC</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#id jeff-dean-ner\n",
    "#caption An example of a sequence annotated with named entities\n",
    "#hide_input\n",
    "import pandas as pd\n",
    "toks = \"Jeff Dean is a computer scientist at Google in California\".split()\n",
    "lbls = [\"B-PER\", \"I-PER\", \"O\", \"O\", \"O\", \"O\", \"O\", \"B-ORG\", \"O\", \"B-LOC\"]\n",
    "df = pd.DataFrame(data=[toks, lbls], index=['Tokens', 'Tags'])\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "XTREME has 183 configurations\n"
     ]
    }
   ],
   "source": [
    "from datasets import get_dataset_config_names\n",
    "\n",
    "xtreme_subsets = get_dataset_config_names(\"xtreme\")\n",
    "print(f\"XTREME has {len(xtreme_subsets)} configurations\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "['PAN-X.af', 'PAN-X.ar', 'PAN-X.bg']"
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "panx_subsets = [s for s in xtreme_subsets if s.startswith(\"PAN\")]\n",
    "panx_subsets[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "  0%|          | 0/3 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "639d082faa344b1fa8c692d181e15cc2"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": "DatasetDict({\n    validation: Dataset({\n        features: ['tokens', 'ner_tags', 'langs'],\n        num_rows: 10000\n    })\n    test: Dataset({\n        features: ['tokens', 'ner_tags', 'langs'],\n        num_rows: 10000\n    })\n    train: Dataset({\n        features: ['tokens', 'ner_tags', 'langs'],\n        num_rows: 20000\n    })\n})"
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide_output\n",
    "from datasets import load_dataset\n",
    "\n",
    "load_dataset(\"xtreme\", name=\"PAN-X.de\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "  0%|          | 0/3 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "e8242577523d475eac597a01cfafb16c"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/3 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "a56902efcb5d413d9227eb216ec65584"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/3 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "cd7f2f794f01433bb8130ed83228af62"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/3 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "c8160199756940a6952926cb494756cd"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# hide_output\n",
    "from collections import defaultdict\n",
    "from datasets import DatasetDict\n",
    "\n",
    "langs = [\"de\", \"fr\", \"it\", \"en\"]\n",
    "fracs = [0.629, 0.229, 0.084, 0.059]\n",
    "# Return a DatasetDict if a key doesn't exist\n",
    "panx_ch = defaultdict(DatasetDict)\n",
    "\n",
    "for lang, frac in zip(langs, fracs):\n",
    "    # Load monolingual corpus\n",
    "    ds = load_dataset(\"xtreme\", name=f\"PAN-X.{lang}\")\n",
    "    # Shuffle and downsample each split according to spoken proportion\n",
    "    for split in ds:\n",
    "        panx_ch[lang][split] = (\n",
    "            ds[split]\n",
    "            .shuffle(seed=0)\n",
    "            .select(range(int(frac * ds[split].num_rows))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "                                de    fr    it    en\nNumber of training examples  12580  4580  1680  1180",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>de</th>\n      <th>fr</th>\n      <th>it</th>\n      <th>en</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>Number of training examples</th>\n      <td>12580</td>\n      <td>4580</td>\n      <td>1680</td>\n      <td>1180</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "pd.DataFrame({lang: [panx_ch[lang][\"train\"].num_rows] for lang in langs},\n",
    "             index=[\"Number of training examples\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "outputs": [
    {
     "data": {
      "text/plain": "Dataset({\n    features: ['tokens', 'ner_tags', 'langs'],\n    num_rows: 12580\n})"
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "panx_ch[\"de\"][\"train\"]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tokens: ['2.000', 'Einwohnern', 'an', 'der', 'Danziger', 'Bucht', 'in', 'der',\n",
      "'polnischen', 'Woiwodschaft', 'Pommern', '.']\n",
      "ner_tags: [0, 0, 0, 0, 5, 6, 0, 0, 5, 5, 6, 0]\n",
      "langs: ['de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de']\n"
     ]
    }
   ],
   "source": [
    "element = panx_ch[\"de\"][\"train\"][0]\n",
    "for key, value in element.items():\n",
    "    print(f\"{key}: {value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tokens: Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)\n",
      "ner_tags: Sequence(feature=ClassLabel(num_classes=7, names=['O', 'B-PER',\n",
      "'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'], names_file=None, id=None),\n",
      "length=-1, id=None)\n",
      "langs: Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)\n"
     ]
    }
   ],
   "source": [
    "for key, value in panx_ch[\"de\"][\"train\"].features.items():\n",
    "    print(f\"{key}: {value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ClassLabel(num_classes=7, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG',\n",
      "'B-LOC', 'I-LOC'], names_file=None, id=None)\n"
     ]
    }
   ],
   "source": [
    "tags = panx_ch[\"de\"][\"train\"].features[\"ner_tags\"].feature\n",
    "print(tags)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "  0%|          | 0/6290 [00:00<?, ?ex/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "afd794cf91af482abe1703398c24f0b2"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/6290 [00:00<?, ?ex/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "46deee92c6d44f68b33efe669d32b26b"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/12580 [00:00<?, ?ex/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "b52ee75e7dc54f28b90d15a740b30c7e"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# hide_output\n",
    "def create_tag_names(batch):\n",
    "    \"\"\"Return a `ner_tags_str` entry with the string name of every tag ID.\"\"\"\n",
    "    tag_names = list(map(tags.int2str, batch[\"ner_tags\"]))\n",
    "    return {\"ner_tags_str\": tag_names}\n",
    "\n",
    "panx_de = panx_ch[\"de\"].map(create_tag_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "           0           1   2    3         4      5   6    7           8   \\\nTokens  2.000  Einwohnern  an  der  Danziger  Bucht  in  der  polnischen   \nTags        O           O   O    O     B-LOC  I-LOC   O    O       B-LOC   \n\n                  9        10 11  \nTokens  Woiwodschaft  Pommern  .  \nTags           B-LOC    I-LOC  O  ",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>0</th>\n      <th>1</th>\n      <th>2</th>\n      <th>3</th>\n      <th>4</th>\n      <th>5</th>\n      <th>6</th>\n      <th>7</th>\n      <th>8</th>\n      <th>9</th>\n      <th>10</th>\n      <th>11</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>Tokens</th>\n      <td>2.000</td>\n      <td>Einwohnern</td>\n      <td>an</td>\n      <td>der</td>\n      <td>Danziger</td>\n      <td>Bucht</td>\n      <td>in</td>\n      <td>der</td>\n      <td>polnischen</td>\n      <td>Woiwodschaft</td>\n      <td>Pommern</td>\n      <td>.</td>\n    </tr>\n    <tr>\n      <th>Tags</th>\n      <td>O</td>\n      <td>O</td>\n      <td>O</td>\n      <td>O</td>\n      <td>B-LOC</td>\n      <td>I-LOC</td>\n      <td>O</td>\n      <td>O</td>\n      <td>B-LOC</td>\n      <td>B-LOC</td>\n      <td>I-LOC</td>\n      <td>O</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide_output\n",
    "de_example = panx_de[\"train\"][0]\n",
    "pd.DataFrame([de_example[\"tokens\"], de_example[\"ner_tags_str\"]],\n",
    "['Tokens', 'Tags'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "outputs": [
    {
     "ename": "KeyError",
     "evalue": "'attention_mask'",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mKeyError\u001B[0m                                  Traceback (most recent call last)",
      "Input \u001B[1;32mIn [63]\u001B[0m, in \u001B[0;36m<cell line: 1>\u001B[1;34m()\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[43mde_example\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mattention_mask\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m]\u001B[49m\n",
      "\u001B[1;31mKeyError\u001B[0m: 'attention_mask'"
     ]
    }
   ],
   "source": [
    "de_example[\"attention_mask\"]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "split2freqs = defaultdict(Counter)\n",
    "for split, dataset in panx_de.items():\n",
    "    for row in dataset[\"ner_tags_str\"]:\n",
    "        for tag in row:\n",
    "            if tag.startswith(\"B\"):\n",
    "                tag_type = tag.split(\"-\")[1]\n",
    "                split2freqs[split][tag_type] += 1\n",
    "pd.DataFrame.from_dict(split2freqs, orient=\"index\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Multilingual Transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## A Closer Look at Tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "from transformers import AutoTokenizer\n",
    "# Load the pretrained BERT and XLM-R tokenizers\n",
    "bert_model_name = \"bert-base-cased\"\n",
    "xlmr_model_name = \"xlm-roberta-base\"\n",
    "bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)\n",
    "xlmr_tokenizer = AutoTokenizer.from_pretrained(xlmr_model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "text = \"Jack Sparrow loves New York!\"\n",
    "bert_tokens = bert_tokenizer(text).tokens()\n",
    "xlmr_tokens = xlmr_tokenizer(text).tokens()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#hide_input\n",
    "df = pd.DataFrame([bert_tokens, xlmr_tokens], index=[\"BERT\", \"XLM-R\"])\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### The Tokenizer Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### The SentencePiece Tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "\"\".join(xlmr_tokens).replace(\"\\u2581\", \" \")"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "pycharm": {
     "name": "#%% raw\n"
    }
   },
   "source": [
    "[[train_ner_tagger]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Transformers for Named Entity Recognition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## The Anatomy of the Transformers Model Class"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Bodies and Heads"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Creating a Custom Model for Token Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "from transformers import XLMRobertaConfig\n",
    "from transformers.modeling_outputs import TokenClassifierOutput\n",
    "from transformers.models.roberta.modeling_roberta import RobertaModel\n",
    "from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel\n",
    "\n",
    "class XLMRobertaForTokenClassification(RobertaPreTrainedModel):\n",
    "    \"\"\"RoBERTa encoder body with a dropout + linear token classification head.\"\"\"\n",
    "\n",
    "    config_class = XLMRobertaConfig\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        self.num_labels = config.num_labels\n",
    "        # Encoder body, built without the pooling layer\n",
    "        self.roberta = RobertaModel(config, add_pooling_layer=False)\n",
    "        # Classification head: dropout followed by a linear layer over each token\n",
    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
    "        self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n",
    "        # Initialize the weights\n",
    "        self.init_weights()\n",
    "\n",
    "    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,\n",
    "                labels=None, **kwargs):\n",
    "        \"\"\"Run the encoder and the head; compute a cross-entropy loss when `labels` is given.\"\"\"\n",
    "        encoder_out = self.roberta(input_ids, attention_mask=attention_mask,\n",
    "                                   token_type_ids=token_type_ids, **kwargs)\n",
    "        # The first element of the encoder output feeds the classification head\n",
    "        logits = self.classifier(self.dropout(encoder_out[0]))\n",
    "        if labels is None:\n",
    "            loss = None\n",
    "        else:\n",
    "            # Flatten to (tokens, num_labels) vs. (tokens,) for the loss\n",
    "            flat_logits = logits.view(-1, self.num_labels)\n",
    "            loss = nn.CrossEntropyLoss()(flat_logits, labels.view(-1))\n",
    "        return TokenClassifierOutput(loss=loss, logits=logits,\n",
    "                                     hidden_states=encoder_out.hidden_states,\n",
    "                                     attentions=encoder_out.attentions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Loading a Custom Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Forward and reverse lookups between class IDs and tag names\n",
    "index2tag = dict(enumerate(tags.names))\n",
    "tag2index = {tag: idx for idx, tag in index2tag.items()}\n",
    "index2tag,tag2index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "from transformers import AutoConfig\n",
    "\n",
    "xlmr_config = AutoConfig.from_pretrained(xlmr_model_name, \n",
    "                                         num_labels=tags.num_classes,\n",
    "                                         id2label=index2tag, label2id=tag2index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "import torch\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "xlmr_model = (XLMRobertaForTokenClassification\n",
    "              .from_pretrained(xlmr_model_name, config=xlmr_config)\n",
    "              .to(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "input_ids = xlmr_tokenizer.encode(text, return_tensors=\"pt\")\n",
    "pd.DataFrame([xlmr_tokens, input_ids[0].numpy()], index=[\"Tokens\", \"Input IDs\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "outputs = xlmr_model(input_ids.to(device)).logits\n",
    "predictions = torch.argmax(outputs, dim=-1)\n",
    "print(f\"Number of tokens in sequence: {len(xlmr_tokens)}\")\n",
    "print(f\"Shape of outputs: {outputs.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "preds = [tags.names[p] for p in predictions[0].cpu().numpy()]\n",
    "pd.DataFrame([xlmr_tokens, preds], index=[\"Tokens\", \"Tags\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def tag_text(text, tags, model, tokenizer):\n",
    "    \"\"\"Tag every token of `text` with its most likely entity class.\n",
    "\n",
    "    Args:\n",
    "        text: Raw input string.\n",
    "        tags: Object exposing a `names` sequence indexed by class ID.\n",
    "        model: Token classification model; its first output is used as the logits.\n",
    "        tokenizer: Tokenizer used for both the displayed tokens and the input IDs.\n",
    "\n",
    "    Returns:\n",
    "        A DataFrame with a \"Tokens\" row and a \"Tags\" row.\n",
    "    \"\"\"\n",
    "    # Tokenize once with the tokenizer that was passed in, so the displayed\n",
    "    # tokens and the IDs fed to the model come from the same encoding\n",
    "    encoding = tokenizer(text, return_tensors=\"pt\")\n",
    "    tokens = encoding.tokens()\n",
    "    # `device` is the notebook-level torch device\n",
    "    input_ids = encoding.input_ids.to(device)\n",
    "    # Inference only, so skip gradient tracking\n",
    "    with torch.no_grad():\n",
    "        outputs = model(input_ids)[0]\n",
    "    # Take argmax over the class dimension to get the most likely class per token\n",
    "    predictions = torch.argmax(outputs, dim=2)\n",
    "    preds = [tags.names[p] for p in predictions[0].cpu().numpy()]\n",
    "    return pd.DataFrame([tokens, preds], index=[\"Tokens\", \"Tags\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Tokenizing Texts for NER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "words, labels = de_example[\"tokens\"], de_example[\"ner_tags\"]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "tokenized_input = xlmr_tokenizer(de_example[\"tokens\"], is_split_into_words=True)\n",
    "tokens = xlmr_tokenizer.convert_ids_to_tokens(tokenized_input[\"input_ids\"])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "#hide_output\n",
    "pd.DataFrame([tokens], index=[\"Tokens\"])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# hide_output\n",
    "word_ids = tokenized_input.word_ids()\n",
    "pd.DataFrame([tokens, word_ids], index=[\"Tokens\", \"Word IDs\"])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "#hide_output\n",
    "# Read the label IDs from the untouched example so that re-running this cell\n",
    "# is safe (the `labels` name is rebound to tag strings further down)\n",
    "word_labels = de_example[\"ner_tags\"]\n",
    "previous_word_idx = None\n",
    "label_ids = []\n",
    "\n",
    "for word_idx in word_ids:\n",
    "    # Special tokens (None) and continuation subwords are masked with -100\n",
    "    if word_idx is None or word_idx == previous_word_idx:\n",
    "        label_ids.append(-100)\n",
    "    else:\n",
    "        label_ids.append(word_labels[word_idx])\n",
    "    previous_word_idx = word_idx\n",
    "\n",
    "labels = [index2tag[l] if l != -100 else \"IGN\" for l in label_ids]\n",
    "index = [\"Tokens\", \"Word IDs\", \"Label IDs\", \"Labels\"]\n",
    "\n",
    "pd.DataFrame([tokens, word_ids, label_ids, labels], index=index)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def tokenize_and_align_labels(examples):\n",
    "    \"\"\"Tokenize a batch of word lists and align the NER tag IDs to the subwords.\n",
    "\n",
    "    Special tokens and non-initial subwords receive the label -100. The aligned\n",
    "    labels are stored under the \"labels\" key of the returned encoding.\n",
    "    \"\"\"\n",
    "    # Encode the pre-split words into tokenizer vocabulary IDs\n",
    "    tokenized_inputs = xlmr_tokenizer(examples[\"tokens\"], truncation=True,\n",
    "                                      is_split_into_words=True)\n",
    "    aligned_labels = []\n",
    "    for batch_idx, word_tags in enumerate(examples[\"ner_tags\"]):\n",
    "        # word_ids maps each subword to the index of the word it came from\n",
    "        word_ids = tokenized_inputs.word_ids(batch_index=batch_idx)\n",
    "        label_ids = []\n",
    "        # Pair every word index with its predecessor (None before the first token)\n",
    "        for prev_idx, word_idx in zip([None] + word_ids, word_ids):\n",
    "            keep = word_idx is not None and word_idx != prev_idx\n",
    "            label_ids.append(word_tags[word_idx] if keep else -100)\n",
    "        aligned_labels.append(label_ids)\n",
    "    # Attach the aligned label IDs to the encoding\n",
    "    tokenized_inputs[\"labels\"] = aligned_labels\n",
    "    return tokenized_inputs"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def encode_panx_dataset(corpus):\n",
    "    \"\"\"Map `tokenize_and_align_labels` over `corpus` in batches, dropping the raw columns.\"\"\"\n",
    "    raw_columns = ['langs', 'ner_tags', 'tokens']\n",
    "    return corpus.map(tokenize_and_align_labels, batched=True,\n",
    "                      remove_columns=raw_columns)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Inspect the first validation example before and after encoding\n",
    "val_sample = panx_ch[\"de\"][\"validation\"][:1]\n",
    "print(val_sample[\"tokens\"][0])\n",
    "# Encode only this one-example batch rather than the entire split\n",
    "val_sample_encoded = tokenize_and_align_labels(val_sample)\n",
    "print(val_sample_encoded['labels'][0])\n",
    "print(xlmr_tokenizer.convert_ids_to_tokens(val_sample_encoded['input_ids'][0]))\n",
    "print(panx_ch[\"de\"][\"validation\"])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# hide_output\n",
    "panx_de_encoded = encode_panx_dataset(panx_ch[\"de\"])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Performance Measures"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from seqeval.metrics import classification_report\n",
    "\n",
    "y_true = [[\"O\", \"O\", \"O\", \"B-MISC\", \"I-MISC\", \"I-MISC\", \"O\"],\n",
    "          [\"B-PER\", \"I-PER\", \"O\"]]\n",
    "y_pred = [[\"O\", \"O\", \"B-MISC\", \"I-MISC\", \"I-MISC\", \"I-MISC\", \"O\"],\n",
    "          [\"B-PER\", \"I-PER\", \"O\"]]\n",
    "print(classification_report(y_true, y_pred))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def model_init():\n",
    "    \"\"\"Return a newly loaded XLM-R token classifier moved to `device`.\"\"\"\n",
    "    model = XLMRobertaForTokenClassification.from_pretrained(\n",
    "        xlmr_model_name, config=xlmr_config)\n",
    "    return model.to(device)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def align_predictions(predictions, label_ids):\n",
    "    \"\"\"Convert logits and label IDs into per-example lists of tag names.\n",
    "\n",
    "    Positions whose label ID is -100 are dropped from both lists.\n",
    "\n",
    "    Returns:\n",
    "        A `(preds_list, labels_list)` tuple of lists of lists of tag names.\n",
    "    \"\"\"\n",
    "    preds = np.argmax(predictions, axis=2)\n",
    "    batch_size, seq_len = preds.shape\n",
    "    labels_list, preds_list = [], []\n",
    "\n",
    "    for i in range(batch_size):\n",
    "        # Sequence positions that carry a real label\n",
    "        kept = [j for j in range(seq_len) if label_ids[i, j] != -100]\n",
    "        labels_list.append([index2tag[label_ids[i][j]] for j in kept])\n",
    "        preds_list.append([index2tag[preds[i][j]] for j in kept])\n",
    "\n",
    "    return preds_list, labels_list"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Fine-Tuning XLM-RoBERTa"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# hide_output\n",
    "from transformers import TrainingArguments\n",
    "\n",
    "num_epochs = 3\n",
    "batch_size = 8\n",
    "# Number of full batches in the German training split\n",
    "logging_steps = len(panx_de_encoded[\"train\"]) // batch_size\n",
    "model_name = f\"{xlmr_model_name}-finetuned-panx-de\"\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=model_name,\n",
    "    log_level=\"error\",\n",
    "    num_train_epochs=num_epochs,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    # NOTE(review): the very large save_steps presumably keeps intermediate\n",
    "    # checkpoints from being written during the run - confirm\n",
    "    save_steps=1e6,\n",
    "    weight_decay=0.01,\n",
    "    disable_tqdm=False,\n",
    "    logging_steps=logging_steps,\n",
    "    push_to_hub=True,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "#hide_output\n",
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from seqeval.metrics import f1_score\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Return the seqeval F1-score for one evaluation result.\n",
    "\n",
    "    `eval_pred` must expose `.predictions` and `.label_ids`; both are passed\n",
    "    through `align_predictions` before scoring.\n",
    "    \"\"\"\n",
    "    aligned = align_predictions(eval_pred.predictions, eval_pred.label_ids)\n",
    "    # align_predictions returns predictions first, labels second\n",
    "    y_pred, y_true = aligned\n",
    "    score = f1_score(y_true, y_pred)\n",
    "    return {\"f1\": score}"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForTokenClassification\n",
    "\n",
    "data_collator = DataCollatorForTokenClassification(xlmr_tokenizer)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#hide\n",
    "%env TOKENIZERS_PARALLELISM=false"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "from transformers import Trainer\n",
    "\n",
    "trainer = Trainer(model_init=model_init, args=training_args, \n",
    "                  data_collator=data_collator, compute_metrics=compute_metrics,\n",
    "                  train_dataset=panx_de_encoded[\"train\"],\n",
    "                  eval_dataset=panx_de_encoded[\"validation\"], \n",
    "                  tokenizer=xlmr_tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#hide_input\n",
    "trainer.train()\n",
    "trainer.push_to_hub(commit_message=\"Training completed!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# hide_input\n",
    "# Raw log key -> column title shown in the table (order fixes column order)\n",
    "log_columns = {\"epoch\": \"Epoch\", \"loss\": \"Training Loss\",\n",
    "               \"eval_loss\": \"Validation Loss\", \"eval_f1\": \"F1\"}\n",
    "df = (pd.DataFrame(trainer.state.log_history)[list(log_columns)]\n",
    "      .rename(columns=log_columns))\n",
    "df[\"Epoch\"] = df[\"Epoch\"].apply(round)\n",
    "# Fill the gaps between log entries so each row carries every value\n",
    "df[\"Training Loss\"] = df[\"Training Loss\"].ffill()\n",
    "df[[\"Validation Loss\", \"F1\"]] = df[[\"Validation Loss\", \"F1\"]].bfill().ffill()\n",
    "df.drop_duplicates()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Error Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from torch.nn.functional import cross_entropy\n",
    "\n",
    "def forward_pass_with_label(batch):\n",
    "    \"\"\"Run `trainer.model` on a batch and return per-token loss and predictions.\n",
    "\n",
    "    Args:\n",
    "        batch: dict of equally long lists, one entry per example.\n",
    "\n",
    "    Returns:\n",
    "        dict with \"loss\" (per-token cross-entropy, one row per example) and\n",
    "        \"predicted_label\" (argmax class index per token), both NumPy arrays\n",
    "        at the padded length produced by `data_collator`.\n",
    "    \"\"\"\n",
    "    # Convert dict of lists to list of dicts suitable for data collator\n",
    "    features = [dict(zip(batch, t)) for t in zip(*batch.values())]\n",
    "    # Pad inputs and labels and put all tensors on device\n",
    "    batch = data_collator(features)\n",
    "    input_ids = batch[\"input_ids\"].to(device)\n",
    "    attention_mask = batch[\"attention_mask\"].to(device)\n",
    "    labels = batch[\"labels\"].to(device)\n",
    "    with torch.no_grad():\n",
    "        # Pass data through model\n",
    "        output = trainer.model(input_ids, attention_mask)\n",
    "        # Logit.size: [batch_size, sequence_length, classes]\n",
    "        # Predict class with largest logit value on classes axis\n",
    "        predicted_label = torch.argmax(output.logits, axis=-1).cpu().numpy()\n",
    "    # Take the number of classes from the logits rather than hard-coding it,\n",
    "    # so the function keeps working if the tag set changes\n",
    "    num_classes = output.logits.size(-1)\n",
    "    # Calculate loss per token after flattening batch dimension with view\n",
    "    loss = cross_entropy(output.logits.view(-1, num_classes),\n",
    "                         labels.view(-1), reduction=\"none\")\n",
    "    # Unflatten batch dimension and convert to numpy array\n",
    "    loss = loss.view(len(input_ids), -1).cpu().numpy()\n",
    "\n",
    "    return {\"loss\": loss, \"predicted_label\": predicted_label}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "valid_set = panx_de_encoded[\"validation\"]\n",
    "valid_set = valid_set.map(forward_pass_with_label, batched=True, batch_size=32)\n",
    "df = valid_set.to_pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "index2tag[-100] = \"IGN\"\n",
    "\n",
    "def _to_tags(ids):\n",
    "    # Map a sequence of class indices to their tag names\n",
    "    return [index2tag[i] for i in ids]\n",
    "\n",
    "def _trim(column):\n",
    "    # Cut a per-token column back to the length of that row's `input_ids`\n",
    "    return df.apply(lambda row: row[column][:len(row[\"input_ids\"])], axis=1)\n",
    "\n",
    "df[\"input_tokens\"] = df[\"input_ids\"].apply(\n",
    "    xlmr_tokenizer.convert_ids_to_tokens)\n",
    "df[\"predicted_label\"] = df[\"predicted_label\"].apply(_to_tags)\n",
    "df[\"labels\"] = df[\"labels\"].apply(_to_tags)\n",
    "df[\"loss\"] = _trim(\"loss\")\n",
    "df[\"predicted_label\"] = _trim(\"predicted_label\")\n",
    "df.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "# One row per token: explode the list columns, drop ignored positions and\n",
    "# round the per-token loss for display\n",
    "df_tokens = (\n",
    "    df.apply(pd.Series.explode)\n",
    "    .query(\"labels != 'IGN'\")\n",
    "    .assign(loss=lambda frame: frame[\"loss\"].astype(float).round(2))\n",
    ")\n",
    "df_tokens.head(7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "(\n",
    "    df_tokens.groupby(\"input_tokens\")[[\"loss\"]]\n",
    "    .agg([\"count\", \"mean\", \"sum\"])\n",
    "    .droplevel(level=0, axis=1)  # Get rid of multi-level columns\n",
    "    .sort_values(by=\"sum\", ascending=False)\n",
    "    .reset_index()\n",
    "    .round(2)\n",
    "    .head(10)\n",
    "    .T\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "(\n",
    "    df_tokens.groupby(\"labels\")[[\"loss\"]] \n",
    "    .agg([\"count\", \"mean\", \"sum\"])\n",
    "    .droplevel(level=0, axis=1)\n",
    "    .sort_values(by=\"mean\", ascending=False)\n",
    "    .reset_index()\n",
    "    .round(2)\n",
    "    .T\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix\n",
    "\n",
    "def plot_confusion_matrix(y_preds, y_true, labels):\n",
    "    \"\"\"Plot a confusion matrix normalized over the true labels.\n",
    "\n",
    "    Args:\n",
    "        y_preds: predicted tag for each token.\n",
    "        y_true: ground-truth tag for each token.\n",
    "        labels: tag names in the order the rows and columns are drawn.\n",
    "    \"\"\"\n",
    "    # Pass `labels` explicitly: without it scikit-learn orders the classes by\n",
    "    # sorting the values found in the data, which need not match the order of\n",
    "    # `labels` used for the tick labels below.\n",
    "    cm = confusion_matrix(y_true, y_preds, labels=labels, normalize=\"true\")\n",
    "    fig, ax = plt.subplots(figsize=(6, 6))\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)\n",
    "    disp.plot(cmap=\"Blues\", values_format=\".2f\", ax=ax, colorbar=False)\n",
    "    plt.title(\"Normalized confusion matrix\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Pass by keyword so predictions and ground truth cannot be swapped: the\n",
    "# function's first parameter is `y_preds`, the second is `y_true`\n",
    "plot_confusion_matrix(y_preds=df_tokens[\"predicted_label\"],\n",
    "                      y_true=df_tokens[\"labels\"], labels=tags.names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "def get_samples(df):\n",
    "    \"\"\"Yield one transposed DataFrame per row of `df` for token-level review.\n",
    "\n",
    "    Each yielded frame has the rows \"tokens\", \"labels\", \"preds\" and\n",
    "    \"losses\" (loss formatted with two decimals). The first and the last\n",
    "    position of every sequence are left out.\n",
    "    \"\"\"\n",
    "    for _, row in df.iterrows():\n",
    "        seq_len = len(row[\"attention_mask\"])\n",
    "        # Valid indices run from 0 to seq_len - 1, so the final position is\n",
    "        # seq_len - 1; comparing against seq_len would never skip it.\n",
    "        skipped = {0, seq_len - 1}\n",
    "        labels, preds, tokens, losses = [], [], [], []\n",
    "        for i in range(seq_len):\n",
    "            if i not in skipped:\n",
    "                labels.append(row[\"labels\"][i])\n",
    "                preds.append(row[\"predicted_label\"][i])\n",
    "                tokens.append(row[\"input_tokens\"][i])\n",
    "                losses.append(f\"{row['loss'][i]:.2f}\")\n",
    "        df_tmp = pd.DataFrame({\"tokens\": tokens, \"labels\": labels,\n",
    "                               \"preds\": preds, \"losses\": losses}).T\n",
    "        yield df_tmp\n",
    "\n",
    "df[\"total_loss\"] = df[\"loss\"].apply(sum)\n",
    "df_tmp = df.sort_values(by=\"total_loss\", ascending=False).head(3)\n",
    "\n",
    "for sample in get_samples(df_tmp):\n",
    "    display(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "df_tmp = df.loc[df[\"input_tokens\"].apply(lambda x: u\"\\u2581(\" in x)].head(2)\n",
    "for sample in get_samples(df_tmp):\n",
    "    display(sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Cross-Lingual Transfer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def get_f1_score(trainer, dataset):\n",
    "    \"\"\"Return the `test_f1` metric reported by `trainer.predict(dataset)`.\"\"\"\n",
    "    prediction_output = trainer.predict(dataset)\n",
    "    return prediction_output.metrics[\"test_f1\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "f1_scores = defaultdict(dict)\n",
    "f1_scores[\"de\"][\"de\"] = get_f1_score(trainer, panx_de_encoded[\"test\"])\n",
    "print(f\"F1-score of [de] model on [de] dataset: {f1_scores['de']['de']:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "text_fr = \"Jeff Dean est informaticien chez Google en Californie\"\n",
    "tag_text(text_fr, tags, trainer.model, xlmr_tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def evaluate_lang_performance(lang, trainer):\n",
    "    \"\"\"Encode `panx_ch[lang]` and return the trainer's F1-score on its test split.\"\"\"\n",
    "    encoded = encode_panx_dataset(panx_ch[lang])\n",
    "    test_split = encoded[\"test\"]\n",
    "    return get_f1_score(trainer, test_split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "f1_scores[\"de\"][\"fr\"] = evaluate_lang_performance(\"fr\", trainer)\n",
    "print(f\"F1-score of [de] model on [fr] dataset: {f1_scores['de']['fr']:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_input\n",
    "print(f\"F1-score of [de] model on [fr] dataset: {f1_scores['de']['fr']:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "f1_scores[\"de\"][\"it\"] = evaluate_lang_performance(\"it\", trainer)\n",
    "print(f\"F1-score of [de] model on [it] dataset: {f1_scores['de']['it']:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_input\n",
    "print(f\"F1-score of [de] model on [it] dataset: {f1_scores['de']['it']:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#hide_output\n",
    "f1_scores[\"de\"][\"en\"] = evaluate_lang_performance(\"en\", trainer)\n",
    "print(f\"F1-score of [de] model on [en] dataset: {f1_scores['de']['en']:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#hide_input\n",
    "print(f\"F1-score of [de] model on [en] dataset: {f1_scores['de']['en']:.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### When Does Zero-Shot Transfer Make Sense?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def train_on_subset(dataset, num_samples):\n",
    "    \"\"\"Fine-tune on `num_samples` shuffled training examples and score the test split.\n",
    "\n",
    "    Args:\n",
    "        dataset: mapping with \"train\", \"validation\" and \"test\" splits.\n",
    "        num_samples: number of training examples kept after shuffling (seed 42).\n",
    "\n",
    "    Returns:\n",
    "        One-row DataFrame with the columns \"num_samples\" and \"f1_score\".\n",
    "\n",
    "    Side effects:\n",
    "        Overwrites `training_args.logging_steps`; pushes to the Hub when\n",
    "        `training_args.push_to_hub` is set.\n",
    "    \"\"\"\n",
    "    train_ds = dataset[\"train\"].shuffle(seed=42).select(range(num_samples))\n",
    "    valid_ds = dataset[\"validation\"]\n",
    "    test_ds = dataset[\"test\"]\n",
    "    # Never let the interval drop to 0 when the subset is smaller than a batch\n",
    "    training_args.logging_steps = max(1, len(train_ds) // batch_size)\n",
    "    \n",
    "    trainer = Trainer(model_init=model_init, args=training_args,\n",
    "        data_collator=data_collator, compute_metrics=compute_metrics,\n",
    "        train_dataset=train_ds, eval_dataset=valid_ds, tokenizer=xlmr_tokenizer)\n",
    "    trainer.train()\n",
    "    if training_args.push_to_hub:\n",
    "        trainer.push_to_hub(commit_message=\"Training completed!\")\n",
    "    \n",
    "    # Named `test_f1` so the imported seqeval `f1_score` is not shadowed here\n",
    "    test_f1 = get_f1_score(trainer, test_ds)\n",
    "    return pd.DataFrame.from_dict(\n",
    "        {\"num_samples\": [len(train_ds)], \"f1_score\": [test_f1]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "panx_fr_encoded = encode_panx_dataset(panx_ch[\"fr\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "training_args.push_to_hub = False\n",
    "metrics_df = train_on_subset(panx_fr_encoded, 250)\n",
    "metrics_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#hide_input\n",
    "# Hack needed to exclude the progress bars in the above cell\n",
    "metrics_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "# `DataFrame.append` was removed in pandas 2.0; collect the per-size results\n",
    "# first and concatenate them onto `metrics_df` in a single step.\n",
    "subset_metrics = [train_on_subset(panx_fr_encoded, num_samples)\n",
    "                  for num_samples in [500, 1000, 2000, 4000]]\n",
    "metrics_df = pd.concat([metrics_df, *subset_metrics], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.axhline(f1_scores[\"de\"][\"fr\"], ls=\"--\", color=\"r\")\n",
    "metrics_df.set_index(\"num_samples\").plot(ax=ax)\n",
    "plt.legend([\"Zero-shot from de\", \"Fine-tuned on fr\"], loc=\"lower right\")\n",
    "plt.ylim((0, 1))\n",
    "plt.xlabel(\"Number of Training Samples\")\n",
    "plt.ylabel(\"F1 Score\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Fine-Tuning on Multiple Languages at Once"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from datasets import concatenate_datasets\n",
    "\n",
    "def concatenate_splits(corpora):\n",
    "    \"\"\"Merge several corpora split by split and shuffle each merged split.\n",
    "\n",
    "    The split names are taken from the first corpus; every merged split is\n",
    "    shuffled with seed 42. Returns a `DatasetDict`.\n",
    "    \"\"\"\n",
    "    merged = {}\n",
    "    for split in corpora[0].keys():\n",
    "        parts = [corpus[split] for corpus in corpora]\n",
    "        merged[split] = concatenate_datasets(parts).shuffle(seed=42)\n",
    "    return DatasetDict(merged)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "panx_de_fr_encoded = concatenate_splits([panx_de_encoded, panx_fr_encoded])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "training_args.logging_steps = len(panx_de_fr_encoded[\"train\"]) // batch_size\n",
    "training_args.push_to_hub = True\n",
    "training_args.output_dir = \"xlm-roberta-base-finetuned-panx-de-fr\"\n",
    "\n",
    "trainer = Trainer(model_init=model_init, args=training_args,\n",
    "    data_collator=data_collator, compute_metrics=compute_metrics,\n",
    "    tokenizer=xlmr_tokenizer, train_dataset=panx_de_fr_encoded[\"train\"],\n",
    "    eval_dataset=panx_de_fr_encoded[\"validation\"])\n",
    "\n",
    "trainer.train()\n",
    "trainer.push_to_hub(commit_message=\"Training completed!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#hide_output\n",
    "for lang in langs:\n",
    "    f1 = evaluate_lang_performance(lang, trainer)\n",
    "    print(f\"F1-score of [de-fr] model on [{lang}] dataset: {f1:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#hide_input\n",
    "for lang in langs:\n",
    "    f1 = evaluate_lang_performance(lang, trainer)\n",
    "    print(f\"F1-score of [de-fr] model on [{lang}] dataset: {f1:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "corpora = [panx_de_encoded]\n",
    "\n",
    "# Exclude German from iteration\n",
    "for lang in langs[1:]:\n",
    "    training_args.output_dir = f\"xlm-roberta-base-finetuned-panx-{lang}\"\n",
    "    # Fine-tune on monolingual corpus\n",
    "    ds_encoded = encode_panx_dataset(panx_ch[lang])\n",
    "    metrics = train_on_subset(ds_encoded, ds_encoded[\"train\"].num_rows)\n",
    "    # Collect F1-scores in common dict\n",
    "    f1_scores[lang][lang] = metrics[\"f1_score\"][0]\n",
    "    # Add monolingual corpus to list of corpora to concatenate\n",
    "    corpora.append(ds_encoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "corpora_encoded = concatenate_splits(corpora)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "training_args.logging_steps = len(corpora_encoded[\"train\"]) // batch_size\n",
    "training_args.output_dir = \"xlm-roberta-base-finetuned-panx-all\"\n",
    "\n",
    "trainer = Trainer(model_init=model_init, args=training_args,\n",
    "    data_collator=data_collator, compute_metrics=compute_metrics,\n",
    "    tokenizer=xlmr_tokenizer, train_dataset=corpora_encoded[\"train\"],\n",
    "    eval_dataset=corpora_encoded[\"validation\"])\n",
    "\n",
    "trainer.train()\n",
    "trainer.push_to_hub(commit_message=\"Training completed!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# hide_output\n",
    "for idx, lang in enumerate(langs):\n",
    "    f1_scores[\"all\"][lang] = get_f1_score(trainer, corpora[idx][\"test\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Diagonal of the score table: each language evaluated on its own model\n",
    "each_scores = {lang: f1_scores[lang][lang] for lang in langs}\n",
    "scores_data = {\"de\": f1_scores[\"de\"],\n",
    "               \"each\": each_scores,\n",
    "               \"all\": f1_scores[\"all\"]}\n",
    "f1_scores_df = (\n",
    "    pd.DataFrame(scores_data).T.round(4)\n",
    "    .rename_axis(index=\"Fine-tune on\", columns=\"Evaluated on\")\n",
    ")\n",
    "f1_scores_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Interacting with Model Widgets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}